from onnx_helper import ONNXClassifierWrapper
import numpy as np
from PyCmpltrtok.common import *

# Smoke test: load a serialized TensorRT engine through ONNXClassifierWrapper
# and run one all-zeros batch through it, then inspect the predictions.

# Serialized TensorRT engine. The "-64" suffix of the file name presumably
# encodes the batch size the engine was built for (TODO confirm); it must agree
# with BATCH_SIZE below.
trt_path = '/home/asuspei/PycharmProjects/AsusCondaP37Torch1101Cuda111/python_ai/category/onnx/torch2onnx/_save/vgg16_torch2onnx.py/v1.0/2022_06_17_12_54_40_476426-64.trt'
# Batch size used both to configure the wrapper and to build the input batch.
# Keep a single value: feeding a batch whose row count differs from the one the
# engine/wrapper was configured for makes the engine read/write buffers sized
# for a different batch.
BATCH_SIZE = 64
N_CLASSES = 10
PRECISION = np.float32
# The second argument is the expected output shape (rows = BATCH_SIZE,
# columns = N_CLASSES), presumably used by the wrapper to size its output
# buffer — verify against onnx_helper.
trt_model = ONNXClassifierWrapper(trt_path, [BATCH_SIZE, N_CLASSES], target_dtype=PRECISION)

# Dummy input with exactly BATCH_SIZE rows so it matches the engine above.
# NOTE(review): this uses an NHWC layout (batch, 224, 224, 3); a VGG16 exported
# from PyTorch would normally take NCHW (batch, 3, 224, 224). The byte count is
# the same, so an all-zeros batch hides the difference — confirm the engine's
# input binding before feeding real images.
dummy_input_batch = np.zeros((BATCH_SIZE, 224, 224, 3), dtype=PRECISION)

predictions = trt_model.predict(dummy_input_batch)
check_np(predictions, 'predictions')
